import numpy as np
from scipy import sparse
from scipy import linalg as scila
from scipy.sparse import linalg as spla
from scipy.sparse.linalg.eigen.arpack import ArpackNoConvergence
from variables import (Nu, Ny, Nv, dim, ku_grid, ky_grid, kv_grid, kx_grid_original, gridN, tSO, t0, ty, sigma_1,
                       sigma_2, sigma_3, sigma_4)
import Hamiltonians_No_psi
import time
import functools
import multiprocessing


def diag_test(ii):
    """Return ``ii`` multiplied by itself (trivial task used to smoke-test a worker pool)."""
    squared = ii * ii
    return squared


def diag_k(ii, Grid, hspin, num_of_band):
    """Diagonalise the Hamiltonian at one flattened k-point.

    Parameters
    ----------
    ii : int
        Index into the flattened k-point arrays.
    Grid : tuple
        ``(ku, ky, kv)`` arrays, each indexed by ``ii``.
    hspin : tuple
        The four blocks ``(hspin1, hspin2, hspin3, hspin4)`` added to the
        k-dependent blocks returned by ``Hamiltonians_No_psi.Hp``.
    num_of_band : int
        Number of eigenpairs requested from ``eigsh``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_of_band, 2)``: column 0 holds the eigenvalues in
        ascending order, column 1 the matching ``ESigmaZ`` values.
    """
    (hspin1, hspin2, hspin3, hspin4) = hspin
    (ku, ky, kv) = Grid
    hp1, hp4 = Hamiltonians_No_psi.Hp(ku[ii], ky[ii], kv[ii])
    H = (sparse.kron(sigma_1, hp1 + hspin1)
         + sparse.kron(sigma_2, hspin2)
         + sparse.kron(sigma_3, hspin3)
         + sparse.kron(sigma_4, hp4 + hspin4))

    # Shift-invert mode: 'LM' with sigma returns the eigenvalues closest to sigma.
    eigvals, eigvecs = spla.eigsh(H, k=num_of_band, which='LM', sigma=-5)
    order = np.argsort(eigvals)
    vecs = eigvecs[:, order]

    # Weight of each eigenvector on the first dim/2 components minus its weight
    # on the remaining ones (presumably <sigma_z> with the spin-up block stored
    # first -- confirm against the basis ordering).
    half = int(dim / 2)
    upper = vecs[0:half, :]
    lower = vecs[half:dim, :]
    spin_z = np.diag(upper.conj().T @ upper - lower.conj().T @ lower).real

    return np.column_stack((eigvals[order], spin_z))


def diagpar(mz, num_of_band, psi):
    """Diagonalise the Hamiltonian on the full (ku, ky, kv) grid using a process pool.

    Each k-point is handled by ``diag_k``; the per-point results are gathered
    and laid out on the grid in the same way as ``diag`` does.

    Parameters
    ----------
    mz, psi
        Forwarded unchanged to ``Hamiltonians_No_psi.Hspin``.
    num_of_band : int
        Number of eigenpairs computed per k-point.

    Returns
    -------
    tuple
        ``(energies, 0, ESigmaZ, ESigmaZ, ESigmaZ)``.  ``energies`` and
        ``ESigmaZ`` have shape ``(num_of_band, Nu, Ny, Nv)``.  The literal ``0``
        sits where ``diag`` returns the eigenvectors, which are not collected
        here.  ``ESigmaZ`` is repeated three times (presumably placeholders for
        the x/y components -- confirm with callers).
    """
    hspin = Hamiltonians_No_psi.Hspin(mz, psi)
    ku, ky, kv = np.meshgrid(ku_grid, ky_grid, kv_grid, indexing='ij')
    # Flatten in Fortran order so that flat index = u + Nu * (y + Ny * v).
    ku = np.reshape(ku, (np.size(ku), 1), order='F')
    ky = np.reshape(ky, (np.size(ky), 1), order='F')
    kv = np.reshape(kv, (np.size(kv), 1), order='F')
    Grid = (ku, ky, kv)
    diag_task = functools.partial(diag_k, Grid=Grid, hspin=hspin, num_of_band=num_of_band)
    n_kpoints = Nu * Ny * Nv

    Cal_Type = 'pool'  # set to anything else to run serially in this process
    if Cal_Type == 'pool':
        multiprocessing.freeze_support()
        # The context manager shuts the workers down once the blocking map has
        # returned, so worker processes are not leaked between calls.
        with multiprocessing.Pool(processes=3) as pool:
            start_time = time.time()
            res = np.array(pool.map(diag_task, range(n_kpoints)))
            print("Time total diagpar pool:", time.time() - start_time)
    else:
        start_time = time.time()
        res = np.array(list(map(diag_task, range(n_kpoints))))
        print("Time total diagpar map:", time.time() - start_time)

    # res has shape (n_kpoints, num_of_band, 2).  Move the band axis first
    # before the Fortran-order reshape; reshaping the (n_kpoints, num_of_band)
    # slice directly would interleave band and k-point indices.
    energies = res[:, :, 0].T
    ESigmaZ = res[:, :, 1].T
    energies = np.reshape(energies, (num_of_band, Nu, Ny, Nv), order='F')
    ESigmaZ = np.reshape(ESigmaZ, (num_of_band, Nu, Ny, Nv), order='F')
    return energies, 0, ESigmaZ, ESigmaZ, ESigmaZ


def diag(mz, num_of_band, t1, t2, t3, t4):
    """Diagonalise the Hamiltonian serially on the full (ku, ky, kv) grid, keeping eigenvectors.

    Parameters
    ----------
    mz, t1, t2, t3, t4
        Forwarded unchanged to ``Hamiltonians_No_psi.Hspin``.
    num_of_band : int
        Number of eigenpairs computed per k-point.

    Returns
    -------
    tuple
        ``(energies, states, ESigmaZ, ESigmaZ, ESigmaZ)`` with shapes
        ``(num_of_band, Nu, Ny, Nv)`` for ``energies``/``ESigmaZ`` and
        ``(dim, num_of_band, Nu, Ny, Nv)`` for ``states``.  ``ESigmaZ`` is
        repeated three times (presumably placeholders for the x/y components --
        confirm with callers).
    """
    n_kpoints = Nu * Ny * Nv
    half = int(dim / 2)

    energies = np.zeros((num_of_band, n_kpoints), dtype=np.float64)
    states = np.zeros((dim, num_of_band, n_kpoints), dtype=np.complex128)
    ESigmaZ = np.zeros((num_of_band, n_kpoints), dtype=np.float64)

    hspin1, hspin2, hspin3, hspin4 = Hamiltonians_No_psi.Hspin(mz, t1, t2, t3, t4)

    # Column vectors of k-points, flattened in Fortran order.
    mesh = np.meshgrid(ku_grid, ky_grid, kv_grid, indexing='ij')
    ku, ky, kv = [np.reshape(axis, (np.size(axis), 1), order='F') for axis in mesh]

    start_time = time.time()
    for ii in range(n_kpoints):
        # Progress report every 5000 k-points.
        if ii % 5000 == 0:
            print(ii / 5000)

        hp1, hp4 = Hamiltonians_No_psi.Hp(ku[ii], ky[ii], kv[ii])
        H = (sparse.kron(sigma_1, hp1 + hspin1)
             + sparse.kron(sigma_2, hspin2)
             + sparse.kron(sigma_3, hspin3)
             + sparse.kron(sigma_4, hp4 + hspin4))

        # Shift-invert mode: eigenvalues closest to sigma, then sorted ascending.
        eigvals, eigvecs = spla.eigsh(H, k=num_of_band, which='LM', sigma=-5)
        order = np.argsort(eigvals)
        energies[:, ii] = eigvals[order]
        states[:, :, ii] = eigvecs[:, order]

        upper = states[0:half, 0:num_of_band, ii]
        lower = states[half:dim, 0:num_of_band, ii]
        # Weight on the first dim/2 components minus weight on the rest.
        spin_z = np.diag(upper.conj().T @ upper - lower.conj().T @ lower)
        ESigmaZ[:, ii] = np.real(spin_z)

        # Sanity check: the expectation value should be purely real.
        # NOTE(review): this blocks on stdin until the user presses Enter.
        if np.count_nonzero(np.imag(spin_z)) != 0:
            input("Wrong!")
    print("Time total diag:", time.time() - start_time)

    grid_shape = (num_of_band, Nu, Ny, Nv)
    energies = np.reshape(energies, grid_shape, order='F')
    states = np.reshape(states, (dim,) + grid_shape, order='F')
    ESigmaZ = np.reshape(ESigmaZ, grid_shape, order='F')

    return energies, states, ESigmaZ, ESigmaZ, ESigmaZ


def diag_no_states(mz: object, num_of_band: object, t1: object, t2: object, t3: object, t4: object) -> object:
    """Diagonalise the Hamiltonian serially on the full grid without storing eigenvectors.

    Same calculation as ``diag``, but only the eigenvalues and ``ESigmaZ`` are
    kept.  The eigenvectors of each k-point are used for the expectation value
    and then discarded, so no ``(dim, num_of_band, Nu*Ny*Nv)`` array is
    allocated.

    Parameters
    ----------
    mz
        Forwarded to ``Hamiltonians_No_psi.Hspin``; also selects the
        shift-invert target (``-199`` when ``mz < -199``, otherwise ``-5``).
    num_of_band
        Number of eigenpairs computed per k-point.
    t1, t2, t3, t4
        Forwarded unchanged to ``Hamiltonians_No_psi.Hspin``.

    Returns
    -------
    tuple
        ``(energies, ESigmaZ)``, both float64 arrays of shape
        ``(num_of_band, Nu, Ny, Nv)``.
    """
    n_kpoints = Nu * Ny * Nv
    half = dim // 2
    energies = np.zeros((num_of_band, n_kpoints), dtype=np.float64)
    ESigmaZ = np.zeros((num_of_band, n_kpoints), dtype=np.float64)
    hspin1, hspin2, hspin3, hspin4 = Hamiltonians_No_psi.Hspin(mz, t1, t2, t3, t4)
    ku, ky, kv = np.meshgrid(ku_grid, ky_grid, kv_grid, indexing='ij')
    # Flatten in Fortran order so the final reshape restores the grid axes.
    ku = np.reshape(ku, (np.size(ku), 1), order='F')
    ky = np.reshape(ky, (np.size(ky), 1), order='F')
    kv = np.reshape(kv, (np.size(kv), 1), order='F')

    # Shift-invert target: eigsh returns the eigenvalues closest to this value.
    shift = -199 if mz < -199 else -5

    for ii in range(n_kpoints):
        hp1, hp4 = Hamiltonians_No_psi.Hp(ku[ii], ky[ii], kv[ii])
        H = sparse.kron(sigma_1, hp1 + hspin1) + sparse.kron(sigma_2, hspin2) \
            + sparse.kron(sigma_3, hspin3) + sparse.kron(sigma_4, hp4 + hspin4)
        envl, envc = spla.eigsh(H, k=num_of_band, which='LM', sigma=shift)
        sort_res = np.argsort(envl)
        energies[:, ii] = envl[sort_res]

        # Only this k-point's eigenvectors are held in memory.
        vecs = envc[:, sort_res].astype(np.complex128)
        Psi_Up = vecs[0:half, :]
        Psi_Dn = vecs[half:dim, :]
        # Weight on the first dim/2 components minus weight on the rest,
        # evaluated once and split into real and imaginary parts.
        spin_z = np.diag(Psi_Up.conj().T @ Psi_Up - Psi_Dn.conj().T @ Psi_Dn)
        ESigmaZ[:, ii] = np.real(spin_z)

        # Sanity check: the expectation value should be purely real.
        # NOTE(review): this blocks on stdin until the user presses Enter.
        if np.count_nonzero(np.imag(spin_z)) != 0:
            input("Wrong!")

    energies = np.reshape(energies, (num_of_band, Nu, Ny, Nv), order='F')
    ESigmaZ = np.reshape(ESigmaZ, (num_of_band, Nu, Ny, Nv), order='F')

    return energies, ESigmaZ


def diagpar_TBM(mz, num_of_band):
    """Evaluate the closed-form two-band energies and ``ESigmaZ`` on the (ku, ky, kv) grid.

    With ``a = mz - 2*t0*cos(pi*ku) - 2*t0*cos(pi*kv) - 2*ty*cos(pi*ky)`` the
    bands are ``E0 = sqrt(4*tSO^2*sin^2(pi*ku) + 4*tSO^2*sin^2(pi*kv) + a^2)``
    and ``E1 = -E0``; ``ESigmaZ`` for each band is ``(a - E) / (2 * E)``.

    Parameters
    ----------
    mz
        Constant term entering ``a``.
    num_of_band : int
        Number of rows in the outputs.  Rows 0 and 1 are filled, so it must be
        at least 2; any further rows stay zero.

    Returns
    -------
    tuple
        ``(energies, ESigmaZ)``, each of shape ``(num_of_band, Nu, Ny, Nv)``.
    """
    n_kpoints = Nu * Ny * Nv
    energies = np.zeros((num_of_band, n_kpoints))
    ESigmaZ = np.zeros((num_of_band, n_kpoints))

    # Column vectors of k-points, flattened in Fortran order.
    mesh = np.meshgrid(ku_grid, ky_grid, kv_grid, indexing='ij')
    ku, ky, kv = [np.reshape(axis, (np.size(axis), 1), order='F') for axis in mesh]

    # Term shared by the energy and the ESigmaZ expressions.
    a = mz - 2 * t0 * np.cos(ku * np.pi) - 2 * t0 * np.cos(kv * np.pi) - 2 * ty * np.cos(ky * np.pi)
    E0 = np.sqrt(4 * tSO * tSO * np.power(np.sin(ku * np.pi), 2)
                 + 4 * tSO * tSO * np.power(np.sin(kv * np.pi), 2)
                 + np.power(a, 2))
    E1 = -E0

    upper_band = E0[:, 0]
    lower_band = E1[:, 0]
    a_flat = a[:, 0]

    energies[0, :] = upper_band
    energies[1, :] = lower_band
    ESigmaZ[0, :] = (a_flat - upper_band) / (2 * upper_band)
    ESigmaZ[1, :] = (a_flat - lower_band) / (2 * lower_band)

    grid_shape = (num_of_band, Nu, Ny, Nv)
    return np.reshape(energies, grid_shape, order='F'), np.reshape(ESigmaZ, grid_shape, order='F')


def diagparxy(mz, num_of_band, kz, psi):
    """Diagonalise the Hamiltonian on a (kx, ky) plane at fixed ``kz``.

    ``kx`` runs over ``gridN`` points in ``[-1 + |kz|, 1 - |kz|]`` and each
    point is mapped to ``ku = kx + kz`` and ``kv = kx - kz``.

    If the sparse shift-invert solve does not converge at a k-point, the
    Hamiltonian is diagonalised densely and the ``num_of_band`` eigenpairs
    closest to the same shift are used.  The fallback therefore selects the
    same bands as the sparse solve and also fills ``states`` and ``ESigmaZ``
    for that k-point.

    Parameters
    ----------
    mz, psi
        Forwarded unchanged to ``Hamiltonians_No_psi.Hspin``.
    num_of_band : int
        Number of eigenpairs computed per k-point.
    kz
        Fixed offset defining the plane; also sets the ``kx`` range.

    Returns
    -------
    tuple
        ``(energies, states, ESigmaZ)``, complex128 arrays of shapes
        ``(num_of_band, Nx, Ny)``, ``(dim, num_of_band, Nx, Ny)`` and
        ``(num_of_band, Nx, Ny)`` with ``Nx = gridN``.
    """
    shift = -5  # shift-invert target shared by the sparse solve and the dense fallback
    half = dim // 2
    kx_grid = np.linspace(-1 + np.abs(kz), 1 - np.abs(kz), gridN)
    Nx = kx_grid.size
    energies = np.zeros((num_of_band, Nx * Ny), dtype=np.complex128)
    states = np.zeros((dim, num_of_band, Nx * Ny), dtype=np.complex128)
    ESigmaZ = np.zeros((num_of_band, Nx * Ny), dtype=np.complex128)
    hspin1, hspin2, hspin3, hspin4 = Hamiltonians_No_psi.Hspin(mz, psi)
    kx, ky = np.meshgrid(kx_grid, ky_grid, indexing='ij')
    # Flatten in Fortran order so the final reshape restores the (Nx, Ny) axes.
    kx = np.reshape(kx, (np.size(kx), 1), order='F')
    ky = np.reshape(ky, (np.size(ky), 1), order='F')
    for ii in range(Nx * Ny):
        # Progress report every 500 k-points.
        if ii % 500 == 0:
            print(ii)
        ku = (kx[ii] + kz)
        kv = (kx[ii] - kz)
        hp1, hp4 = Hamiltonians_No_psi.Hp(ku, ky[ii], kv)
        H = sparse.kron(sigma_1, hp1 + hspin1) + sparse.kron(sigma_2, hspin2) \
            + sparse.kron(sigma_3, hspin3) + sparse.kron(sigma_4, hp4 + hspin4)
        try:
            # Shift-invert mode: 'LM' with sigma returns eigenvalues closest to sigma.
            envl, envc = spla.eigsh(H, k=num_of_band, which='LM', sigma=shift)
        except ArpackNoConvergence:
            print('ku=', ku, 'ky=', ky[ii], 'kv=', kv, 'no convergence. Trying normal methods')
            dense_vals, dense_vecs = scila.eigh(H.toarray())
            # Keep the eigenpairs nearest the shift, matching what eigsh targets.
            nearest = np.argsort(np.abs(dense_vals - shift), kind='stable')[:num_of_band]
            envl = dense_vals[nearest]
            envc = dense_vecs[:, nearest]
        sort_res = np.argsort(envl)
        energies[:, ii] = envl[sort_res]
        states[:, :, ii] = envc[:, sort_res]
        Psi_Up = states[0:half, :, ii]
        Psi_Dn = states[half:dim, :, ii]
        # Weight on the first dim/2 components minus weight on the rest.
        ESigmaZ[:, ii] = np.diag(Psi_Up.conj().T @ Psi_Up - Psi_Dn.conj().T @ Psi_Dn)
    energies = np.reshape(energies, (num_of_band, Nx, Ny), order='F')
    states = np.reshape(states, (dim, num_of_band, Nx, Ny), order='F')
    ESigmaZ = np.reshape(ESigmaZ, (num_of_band, Nx, Ny), order='F')

    return energies, states, ESigmaZ
